{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.0"
    },
    "colab": {
      "name": "segmentation-tutorial.ipynb",
      "provenance": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FRWGyaT2Y25j"
      },
      "source": [
        "# Catalyst segmentation tutorial\n",
        "\n",
        "Authors: Dmitry Bleklov, Roman Tezikov\n",
        "\n",
        "### Colab extras\n",
        "\n",
        "First of all, do not forget to change the runtime type to GPU. <br/>\n",
        "To do so click `Runtime` -> `Change runtime type` -> Select `\"Python 3\"` and `\"GPU\"` -> click `Save`. <br/>\n",
        "After that you can click `Runtime` -> `Run` all and watch the tutorial.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hHJAs8U5Y25m"
      },
      "source": [
        "\n",
        "## Requirements\n",
        "\n",
        "Download and install the latest version of catalyst and other libraries required for this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
        "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
        "colab_type": "code",
        "id": "xCoyLtaeY25m",
        "scrolled": true,
        "colab": {}
      },
      "source": [
        "!pip install -U catalyst\n",
        "\n",
        "# this variable will be used in `runner.train` and by default we disable FP16 mode\n",
        "is_fp16_used = False\n",
        "\n",
        "# for augmentations\n",
        "!pip install -U albumentations\n",
        "\n",
        "# for TTA\n",
        "!pip install ttach\n",
        "\n",
        "# for tensorboard\n",
        "!pip install tensorflow\n",
        "%load_ext tensorboard\n",
        "\n",
        "\n",
        "# if Your machine doesn't support FP16, comment this 3 lines below\n",
        "!git clone https://github.com/NVIDIA/apex\n",
        "!pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./apex\n",
        "is_fp16_used = True"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MncoA-G0Y25p"
      },
      "source": [
        "## Setting up GPUs\n",
        "PyTorch and Catalyst versions:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8H65wGVbY25q",
        "scrolled": true,
        "colab": {}
      },
      "source": [
        "import torch, catalyst\n",
        "\n",
        "torch.__version__, catalyst.__version__"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6DDk9JviY25r"
      },
      "source": [
        "You can also specify GPU/CPU usage for this turorial.\n",
        "\n",
        "Available GPUs"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ibixGTKvY25s",
        "colab": {}
      },
      "source": [
        "from catalyst.utils import get_available_gpus\n",
        "\n",
        "get_available_gpus()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "cXVyQyDsY25u",
        "colab": {}
      },
      "source": [
        "import os\n",
        "from typing import List, Tuple, Callable\n",
        "\n",
        "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # \"\" - CPU, \"0\" - 1 GPU, \"0,1\" - MultiGPU"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0GdMd4MdY25w"
      },
      "source": [
        "---\n",
        "\n",
        "## Reproducibility first\n",
        "\n",
        "Let's set our seed and set the CUDA settings to deterministic mode."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "RpthPML1Y25x",
        "colab": {}
      },
      "source": [
        "from catalyst.utils import set_global_seed, prepare_cudnn\n",
        "\n",
        "SEED = 42\n",
        "\n",
        "set_global_seed(SEED)\n",
        "prepare_cudnn(deterministic=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
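    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind seeding can be sketched with the standard library and NumPy alone. This is a minimal illustration, not Catalyst's implementation of `set_global_seed`; the helper names `seed_everything` and `draw` are hypothetical. Re-seeding with the same value reproduces the same pseudo-random draws.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def seed_everything(seed: int) -> None:\n",
        "    # hypothetical helper: seeds Python's and NumPy's global generators\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "\n",
        "\n",
        "def draw():\n",
        "    # one draw from each global generator\n",
        "    return random.random(), float(np.random.rand())\n",
        "\n",
        "\n",
        "seed_everything(42)\n",
        "first = draw()\n",
        "seed_everything(42)\n",
        "second = draw()\n",
        "print(first == second)  # True: same seed, same draws\n",
        "```\n",
        "\n",
        "Deterministic cuDNN settings address a separate source of run-to-run variation on the GPU, which seeding alone does not cover."
      ]
    },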
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nWseoJqWY25z"
      },
      "source": [
        "## Dataset\n",
        "\n",
        "As a data set we will take Carvana - binary segmentation for the \"car\" class."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "34H0dXBYZepr"
      },
      "source": [
        "If you are on MacOS and you don’t have `wget`, you can install it with: `brew install wget`.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8H8tDrZ6Y250",
        "colab": {}
      },
      "source": [
        "%%bash\n",
        "\n",
        "function gdrive_download () {\n",
        "  CONFIRM=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate \"https://docs.google.com/uc?export=download&id=$1\" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')\n",
        "  wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$CONFIRM&id=$1\" -O $2\n",
        "  rm -rf /tmp/cookies.txt\n",
        "}\n",
        "\n",
        "gdrive_download 1iYaNijLmzsrMlAdMoUEhhJuo-5bkeAuj segmentation_data.zip\n",
        "\n",
        "unzip segmentation_data.zip &>/dev/null"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "g4Vqm9FzY254",
        "colab": {}
      },
      "source": [
        "from pathlib import Path\n",
        "\n",
        "ROOT = Path(\"segmentation_data/\")\n",
        "\n",
        "train_image_path = ROOT / \"train\"\n",
        "train_mask_path = ROOT / \"train_masks\"\n",
        "test_image_path = ROOT / \"test\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vL0sd6__M_98"
      },
      "source": [
        "Collect images and masks into variables."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wU9IPpyAfhOy",
        "colab": {}
      },
      "source": [
        "ALL_IMAGES = sorted(train_image_path.glob(\"*.jpg\"))\n",
        "len(ALL_IMAGES)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ElvJQrlOf4Gm",
        "colab": {}
      },
      "source": [
        "ALL_MASKS = sorted(train_mask_path.glob(\"*.gif\"))\n",
        "len(ALL_MASKS)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "6mcVBd_WjI43",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "from skimage.io import imread as gif_imread\n",
        "from catalyst import utils\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import random\n",
        "\n",
        "\n",
        "def show_examples(name: str, image: np.ndarray, mask: np.ndarray):\n",
        "    plt.figure(figsize=(10, 14))\n",
        "    plt.subplot(1, 2, 1)\n",
        "    plt.imshow(image)\n",
        "    plt.title(f\"Image: {name}\")\n",
        "\n",
        "    plt.subplot(1, 2, 2)\n",
        "    plt.imshow(mask)\n",
        "    plt.title(f\"Mask: {name}\")\n",
        "\n",
        "\n",
        "def show(index: int, images: List[Path], masks: List[Path], transforms=None) -> None:\n",
        "    image_path = images[index]\n",
        "    name = image_path.name\n",
        "\n",
        "    image = utils.imread(image_path)\n",
        "    mask = gif_imread(masks[index])\n",
        "\n",
        "    if transforms is not None:\n",
        "        temp = transforms(image=image, mask=mask)\n",
        "        image = temp[\"image\"]\n",
        "        mask = temp[\"mask\"]\n",
        "\n",
        "    show_examples(name, image, mask)\n",
        "\n",
        "def show_random(images: List[Path], masks: List[Path], transforms=None) -> None:\n",
        "    length = len(images)\n",
        "    index = random.randint(0, length - 1)\n",
        "    show(index, images, masks, transforms)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "f0zVSTYtk8hf"
      },
      "source": [
        "You can restart the cell below to see more examples."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "r0EZVF1pk3tC",
        "colab": {}
      },
      "source": [
        "show_random(ALL_IMAGES, ALL_MASKS)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JpEVELsNNMTO"
      },
      "source": [
        "The dataset below reads images and masks and optionally applies augmentation to them."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Grrv0FqpY25-",
        "colab": {}
      },
      "source": [
        "from typing import List\n",
        "\n",
        "from torch.utils.data import Dataset\n",
        "\n",
        "\n",
        "class SegmentationDataset(Dataset):\n",
        "    def __init__(\n",
        "        self,\n",
        "        images: List[Path],\n",
        "        masks: List[Path] = None,\n",
        "        transforms=None\n",
        "    ) -> None:\n",
        "        self.images = images\n",
        "        self.masks = masks\n",
        "        self.transforms = transforms\n",
        "\n",
        "    def __len__(self) -> int:\n",
        "        return len(self.images)\n",
        "\n",
        "    def __getitem__(self, idx: int) -> dict:\n",
        "        image_path = self.images[idx]\n",
        "        image = utils.imread(image_path)\n",
        "        \n",
        "        result = {\"image\": image}\n",
        "        \n",
        "        if self.masks is not None:\n",
        "          mask = gif_imread(self.masks[idx])\n",
        "          result[\"mask\"] = mask\n",
        "        \n",
        "        if self.transforms is not None:\n",
        "            result = self.transforms(**result)\n",
        "        \n",
        "        result[\"filename\"] = image_path.name\n",
        "\n",
        "        return result"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DsI3ZS2asQqg"
      },
      "source": [
        "### Augmentations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RF0wtzsjNZQ5"
      },
      "source": [
        "The [albumentation](https://github.com/albu/albumentations) library works with images and masks at the same time, which is what we need."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "zNdK5P0UY26A",
        "colab": {}
      },
      "source": [
        "import albumentations as A\n",
        "from albumentations.pytorch import ToTensor\n",
        "\n",
        "\n",
        "def pre_transforms(image_size=224):\n",
        "    return [A.Resize(image_size, image_size, p=1)]\n",
        "  \n",
        "\n",
        "def resize_transforms(image_size=224):\n",
        "    BORDER_CONSTANT = 0\n",
        "    pre_size = int(image_size * 1.5)\n",
        "\n",
        "    random_crop = A.Compose([\n",
        "      A.SmallestMaxSize(pre_size, p=1),\n",
        "      A.RandomCrop(\n",
        "          image_size, image_size, p=1\n",
        "      )\n",
        "\n",
        "    ])\n",
        "\n",
        "    rescale = A.Compose([A.Resize(image_size, image_size, p=1)])\n",
        "\n",
        "    random_crop_big = A.Compose([\n",
        "      A.LongestMaxSize(pre_size, p=1),\n",
        "      A.RandomCrop(\n",
        "          image_size, image_size, p=1\n",
        "      )\n",
        "\n",
        "    ])\n",
        "\n",
        "    # Converts the image to a square of size image_size x image_size\n",
        "    result = [\n",
        "      A.OneOf([\n",
        "          random_crop,\n",
        "          rescale,\n",
        "          random_crop_big\n",
        "      ], p=1)\n",
        "    ]\n",
        "\n",
        "    return result\n",
        "  \n",
        "def post_transforms():\n",
        "    # we use ImageNet image normalization\n",
        "    # and convert it to torch.Tensor\n",
        "    return [A.Normalize(), ToTensor()]\n",
        "  \n",
        "def compose(transforms_to_compose):\n",
        "    # combine all augmentations into one single pipeline\n",
        "    result = A.Compose([\n",
        "      item for sublist in transforms_to_compose for item in sublist\n",
        "    ])\n",
        "    return result\n",
        "\n",
        "\n",
        "def hard_transforms():\n",
        "    result = [\n",
        "      A.RandomRotate90(),\n",
        "      A.Cutout(),\n",
        "      A.RandomBrightnessContrast(\n",
        "          brightness_limit=0.2, contrast_limit=0.2, p=0.3\n",
        "      ),\n",
        "      A.GridDistortion(p=0.3),\n",
        "      A.HueSaturationValue(p=0.3)\n",
        "    ]\n",
        "\n",
        "    return result\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "mrRzeFppnPMt",
        "colab": {}
      },
      "source": [
        "train_transforms = compose([resize_transforms(), hard_transforms(), post_transforms()])\n",
        "valid_transforms = compose([pre_transforms(), post_transforms()])\n",
        "\n",
        "show_transforms = compose([resize_transforms(), hard_transforms()])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CTiQd-frN74v"
      },
      "source": [
        "You can restart the cell below to see more examples of augmentations."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "gGIbhBm1orXt",
        "colab": {}
      },
      "source": [
        "show_random(ALL_IMAGES, ALL_MASKS, transforms=show_transforms)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RvLlx1OvosGX"
      },
      "source": [
        "-------"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fGotRWGjyYOw"
      },
      "source": [
        "## Loaders"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WIfyQYAhY26C",
        "colab": {}
      },
      "source": [
        "import collections\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "def get_loaders(\n",
        "    images: List[Path],\n",
        "    masks: List[Path],\n",
        "    random_state: int,\n",
        "    valid_size: float = 0.2,\n",
        "    batch_size: int = 32,\n",
        "    num_workers: int = 4,\n",
        "    train_transforms_fn = None,\n",
        "    valid_transforms_fn = None,\n",
        ") -> dict:\n",
        "\n",
        "    indices = np.arange(len(images))\n",
        "\n",
        "    # Let's divide the data set into train and valid parts.\n",
        "    train_indices, valid_indices = train_test_split(\n",
        "      indices, test_size=valid_size, random_state=random_state, shuffle=True\n",
        "    )\n",
        "\n",
        "    np_images = np.array(images)\n",
        "    np_masks = np.array(masks)\n",
        "\n",
        "    # Creates our train dataset\n",
        "    train_dataset = SegmentationDataset(\n",
        "      images = np_images[train_indices].tolist(),\n",
        "      masks = np_masks[train_indices].tolist(),\n",
        "      transforms = train_transforms_fn\n",
        "    )\n",
        "\n",
        "    # Creates our valid dataset\n",
        "    valid_dataset = SegmentationDataset(\n",
        "      images = np_images[valid_indices].tolist(),\n",
        "      masks = np_masks[valid_indices].tolist(),\n",
        "      transforms = valid_transforms_fn\n",
        "    )\n",
        "\n",
        "    # Catalyst uses normal torch.data.DataLoader\n",
        "    train_loader = DataLoader(\n",
        "      train_dataset,\n",
        "      batch_size=batch_size,\n",
        "      shuffle=True,\n",
        "      num_workers=num_workers\n",
        "    )\n",
        "\n",
        "    valid_loader = DataLoader(\n",
        "      valid_dataset,\n",
        "      batch_size=batch_size,\n",
        "      shuffle=False,\n",
        "      num_workers=num_workers\n",
        "    )\n",
        "\n",
        "    # And excpect to get an OrderedDict of loaders\n",
        "    loaders = collections.OrderedDict()\n",
        "    loaders[\"train\"] = train_loader\n",
        "    loaders[\"valid\"] = valid_loader\n",
        "\n",
        "    return loaders"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "OG6_pgmkxk_b",
        "colab": {}
      },
      "source": [
        "if is_fp16_used:\n",
        "  batch_size = 64\n",
        "else:\n",
        "  batch_size = 32\n",
        "\n",
        "loaders = get_loaders(\n",
        "    images=ALL_IMAGES,\n",
        "    masks=ALL_MASKS,\n",
        "    random_state=SEED,\n",
        "    train_transforms_fn=train_transforms,\n",
        "    valid_transforms_fn=valid_transforms,\n",
        "    batch_size=batch_size\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5BF2Vgdly9RR"
      },
      "source": [
        "## Model\n",
        "\n",
        "The Catalyst has many segmentation models. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Zm7JsNrczOQG",
        "colab": {}
      },
      "source": [
        "from catalyst.contrib.models import ResnetFPNUnet\n",
        "\n",
        "model = ResnetFPNUnet(num_classes=1, arch=\"resnet18\", pretrained=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GXxJFBUkybYs"
      },
      "source": [
        "## Model training\n",
        "### Preparing training params.\n",
        "\n",
        "We will optimize loss as the sum of IoU, Dice and BCE, specifically this function: $IoU + Dice + 0.8*BCE$.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "nhVSEDGbY26G",
        "colab": {}
      },
      "source": [
        "from torch import nn\n",
        "\n",
        "from catalyst.contrib.criterion import DiceLoss, IoULoss\n",
        "\n",
        "# we have multiple criterions\n",
        "criterion = {\n",
        "    \"dice\": DiceLoss(),\n",
        "    \"iou\": IoULoss(),\n",
        "    \"bce\": nn.BCEWithLogitsLoss()\n",
        "}"
      ],
      "execution_count": 0,
      "outputs": []
    },
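    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the formula concrete, here is a minimal NumPy sketch of $IoU + Dice + 0.8 \\cdot BCE$ for a single predicted mask. It assumes the common soft formulations computed on sigmoid probabilities; Catalyst's `DiceLoss` and `IoULoss` may differ in details such as smoothing or thresholding, and all function names and the `eps` value below are illustrative.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def sigmoid(x):\n",
        "    return 1.0 / (1.0 + np.exp(-x))\n",
        "\n",
        "\n",
        "def soft_iou_loss(probs, target, eps=1e-7):\n",
        "    # 1 - intersection / union, on probabilities instead of hard masks\n",
        "    intersection = (probs * target).sum()\n",
        "    union = probs.sum() + target.sum() - intersection\n",
        "    return 1.0 - intersection / (union + eps)\n",
        "\n",
        "\n",
        "def soft_dice_loss(probs, target, eps=1e-7):\n",
        "    # 1 - 2 * intersection / (sum of both masks)\n",
        "    intersection = (probs * target).sum()\n",
        "    return 1.0 - 2.0 * intersection / (probs.sum() + target.sum() + eps)\n",
        "\n",
        "\n",
        "def bce_loss(probs, target, eps=1e-7):\n",
        "    # mean binary cross-entropy; clipping avoids log(0)\n",
        "    probs = np.clip(probs, eps, 1.0 - eps)\n",
        "    return float(-(target * np.log(probs) + (1 - target) * np.log(1 - probs)).mean())\n",
        "\n",
        "\n",
        "def combined_loss(logits, target, bce_weight=0.8):\n",
        "    probs = sigmoid(logits)\n",
        "    return (\n",
        "        soft_iou_loss(probs, target)\n",
        "        + soft_dice_loss(probs, target)\n",
        "        + bce_weight * bce_loss(probs, target)\n",
        "    )\n",
        "\n",
        "\n",
        "target = np.array([[1.0, 0.0], [1.0, 0.0]])\n",
        "good_logits = np.array([[8.0, -8.0], [8.0, -8.0]])  # confident and correct\n",
        "bad_logits = -good_logits  # confident and wrong\n",
        "\n",
        "print(combined_loss(good_logits, target) < combined_loss(bad_logits, target))\n",
        "```\n",
        "\n",
        "The IoU and Dice terms measure region overlap, while BCE penalizes each pixel independently, so the sum combines both kinds of signal."
      ]
    },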
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "IirWWkf8PeXh",
        "colab": {}
      },
      "source": [
        "from torch import optim\n",
        "\n",
        "from catalyst.contrib.optimizers import RAdam, Lookahead\n",
        "\n",
        "learning_rate = 0.001\n",
        "encoder_learning_rate = 0.0005\n",
        "\n",
        "# Since we use a pre-trained encoder, we will reduce the learning rate on it.\n",
        "layerwise_params = {\"encoder*\": dict(lr=encoder_learning_rate, weight_decay=0.00003)}\n",
        "\n",
        "# This function removes weight_decay for biases and applies our layerwise_params\n",
        "model_params = utils.process_model_params(model, layerwise_params=layerwise_params)\n",
        "\n",
        "# Catalyst has new SOTA optimizers out of box\n",
        "base_optimizer = RAdam(model_params, lr=learning_rate, weight_decay=0.0003)\n",
        "optimizer = Lookahead(base_optimizer)\n",
        "\n",
        "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3pTmRiOfY26I",
        "colab": {}
      },
      "source": [
        "from catalyst.dl import SupervisedRunner\n",
        "\n",
        "num_epochs = 3\n",
        "logdir = \"./logs/segmentation\"\n",
        "\n",
        "device = utils.get_device()\n",
        "print(f\"device: {device}\")\n",
        "\n",
        "if is_fp16_used:\n",
        "  fp16_params = dict(opt_level=\"O1\") # params for FP16\n",
        "else:\n",
        "  fp16_params = None\n",
        "\n",
        "print(f\"FP16 params: {fp16_params}\")\n",
        "\n",
        "\n",
        "# by default SupervisedRunner uses \"features\" and \"targets\", in our case we get \"image\" and \"mask\" keys in dataset __getitem__\n",
        "runner = SupervisedRunner(device=device, input_key=\"image\", input_target_key=\"mask\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "n5mvw0fO1xI5"
      },
      "source": [
        "### Monitoring in tensorboard"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "MUiqd4dOSOOl"
      },
      "source": [
        "If you do not have a Tensorboard opened after you have run the cell below, try running the cell again."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "DnQZm-Ka1wdI",
        "colab": {}
      },
      "source": [
        "%tensorboard --logdir {logdir}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AuIsPPy1MrQG",
        "colab_type": "text"
      },
      "source": [
        "### Running the train-loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VqBlC5_iY26K",
        "colab": {}
      },
      "source": [
        "from catalyst.dl.callbacks import DiceCallback, IouCallback, \\\n",
        "  CriterionCallback, CriterionAggregatorCallback\n",
        "\n",
        "runner.train(\n",
        "    model=model,\n",
        "    criterion=criterion,\n",
        "    optimizer=optimizer,\n",
        "    scheduler=scheduler,\n",
        "    \n",
        "    # our dataloaders\n",
        "    loaders=loaders,\n",
        "    \n",
        "    callbacks=[\n",
        "        # Each criterion is calculated separately.\n",
        "        CriterionCallback(\n",
        "            input_key=\"mask\",\n",
        "            prefix=\"loss_dice\",\n",
        "            criterion_key=\"dice\"\n",
        "        ),\n",
        "        CriterionCallback(\n",
        "            input_key=\"mask\",\n",
        "            prefix=\"loss_iou\",\n",
        "            criterion_key=\"iou\"\n",
        "        ),\n",
        "        CriterionCallback(\n",
        "            input_key=\"mask\",\n",
        "            prefix=\"loss_bce\",\n",
        "            criterion_key=\"bce\",\n",
        "            multiplier=0.8\n",
        "        ),\n",
        "        \n",
        "        # And only then we aggregate everything into one loss.\n",
        "        CriterionAggregatorCallback(\n",
        "            prefix=\"loss\",\n",
        "            loss_keys=[\"loss_dice\", \"loss_iou\", \"loss_bce\"],\n",
        "            loss_aggregate_fn=\"sum\" # or \"mean\"\n",
        "        ),\n",
        "        \n",
        "        # metrics\n",
        "        DiceCallback(input_key=\"mask\"),\n",
        "        IouCallback(input_key=\"mask\"),\n",
        "    ],\n",
        "    # path to save logs\n",
        "    logdir=logdir,\n",
        "    \n",
        "    num_epochs=num_epochs,\n",
        "    \n",
        "    # save our best checkpoint by IoU metric\n",
        "    main_metric=\"iou\",\n",
        "    # IoU needs to be maximized.\n",
        "    minimize_metric=False,\n",
        "    \n",
        "    # for FP16. It uses the variable from the very first cell\n",
        "    fp16=fp16_params,\n",
        "    \n",
        "    # prints train logs\n",
        "    verbose=True\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tRH8jhg3Q_zB"
      },
      "source": [
        "## Model inference\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XyugaCieRnx2",
        "colab": {}
      },
      "source": [
        "TEST_IMAGES = sorted(test_image_path.glob(\"*.jpg\"))\n",
        "len(TEST_IMAGES)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TjyIYXUa1T9-",
        "colab": {}
      },
      "source": [
        "# create test dataset\n",
        "test_dataset = SegmentationDataset(\n",
        "    TEST_IMAGES, \n",
        "    transforms=valid_transforms\n",
        ")\n",
        "\n",
        "\n",
        "num_workers: int = 4\n",
        "\n",
        "infer_loader = DataLoader(\n",
        "    test_dataset,\n",
        "    batch_size=batch_size,\n",
        "    shuffle=False,\n",
        "    num_workers=num_workers\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "LY4TcgwVSBRN",
        "colab": {}
      },
      "source": [
        "\n",
        "# this get predictions for the whole loader\n",
        "predictions = runner.predict_loader(\n",
        "    model=model,\n",
        "    loader=infer_loader,\n",
        "    resume=f\"{logdir}/checkpoints/best.pth\",\n",
        "    verbose=True,\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WHkEcZxMUVKV"
      },
      "source": [
        "The prediction type is `np.ndarray`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fFMg6UzCTj5W",
        "colab": {}
      },
      "source": [
        "print(type(predictions))\n",
        "print(predictions.shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "c57ZjVTxT-9s",
        "colab": {}
      },
      "source": [
        "threshold = 0.5\n",
        "max_count = 5\n",
        "\n",
        "\n",
        "def detach(tensor: torch.Tensor) -> np.ndarray:\n",
        "    return tensor.detach().cpu().numpy()\n",
        "\n",
        "for i, (features, logits) in enumerate(zip(test_dataset, predictions)):\n",
        "    image = detach(features[\"image\"])\n",
        "    \n",
        "    \n",
        "    image = image.transpose(1, 2, 0)\n",
        "    mask_ = torch.from_numpy(logits[0]).sigmoid()\n",
        "    mask = detach(mask_ > threshold).astype(\"uint8\")\n",
        "        \n",
        "    show_examples(name=\"\", image=image, mask=mask)\n",
        "    \n",
        "    if i >= max_count:\n",
        "        break"
      ],
      "execution_count": 0,
      "outputs": []
    },
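    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The post-processing step in the loop above turns raw logits into a binary mask by applying a sigmoid and a threshold. A small NumPy sketch with made-up logit values shows the same steps, and also that a 0.5 threshold on probabilities is equivalent to a 0 threshold on logits.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# illustrative logits for a 2 x 2 image, not real model output\n",
        "logits = np.array([[-2.0, 0.5], [3.0, -0.1]])\n",
        "\n",
        "probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid\n",
        "mask = (probs > 0.5).astype('uint8')\n",
        "\n",
        "print(mask.tolist())  # [[0, 1], [1, 0]]\n",
        "print(np.array_equal(mask, (logits > 0).astype('uint8')))  # True\n",
        "```"
      ]
    },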
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "n-B5OAzXsJkB"
      },
      "source": [
        "### Test-time augmentations (TTA)\n",
        "\n",
        "`ttach` is a new awesome library for test-time augmentation for segmentation or classification tasks"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SnYuacitU_Kj",
        "colab": {}
      },
      "source": [
        "import ttach as tta\n",
        "\n",
        "# D4 makes horizontal and vertical flips + rotations for [0, 90, 180, 270] angels.\n",
        "# and then merges the result masks with merge_mode=\"mean\"\n",
        "tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode=\"mean\")\n",
        "\n",
        "tta_runner = SupervisedRunner(\n",
        "    model=tta_model,\n",
        "    device=utils.get_device(),\n",
        "    input_key=\"image\"\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
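    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The idea behind D4 test-time augmentation can be sketched in plain NumPy: apply each of the eight flip/rotation combinations, run the model, undo the transform on the prediction, and average the results. This is an illustration only, not how `ttach` is implemented; `d4_tta` and `toy_model` are hypothetical names, and the toy model stands in for a real segmentation network.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def d4_tta(model, image):\n",
        "    # image: 2D array (height x width); model maps an image to a same-size mask\n",
        "    outputs = []\n",
        "    for flip in (False, True):\n",
        "        for k in range(4):  # rotations by 0, 90, 180, 270 degrees\n",
        "            augmented = np.fliplr(image) if flip else image\n",
        "            augmented = np.rot90(augmented, k)\n",
        "            prediction = model(augmented)\n",
        "            # undo the augmentation in reverse order\n",
        "            restored = np.rot90(prediction, -k)\n",
        "            restored = np.fliplr(restored) if flip else restored\n",
        "            outputs.append(restored)\n",
        "    # merge the eight aligned predictions by averaging\n",
        "    return np.mean(outputs, axis=0)\n",
        "\n",
        "\n",
        "def toy_model(x):\n",
        "    # stand-in for a network: a pixel-wise function\n",
        "    return x * 2.0\n",
        "\n",
        "\n",
        "image = np.arange(16, dtype=float).reshape(4, 4)\n",
        "print(np.allclose(d4_tta(toy_model, image), toy_model(image)))\n",
        "```\n",
        "\n",
        "For a pixel-wise model the averaged result equals the plain prediction. A real network is generally not equivariant to flips and rotations, so its eight predictions differ and the average can smooth out orientation-specific errors."
      ]
    },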
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8yeF2OAVs1xY",
        "colab": {}
      },
      "source": [
        "infer_loader = DataLoader(\n",
        "    test_dataset,\n",
        "    batch_size=1,\n",
        "    shuffle=False,\n",
        "    num_workers=num_workers\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "W6ZRf_VjvAFc",
        "colab": {}
      },
      "source": [
        "batch = next(iter(infer_loader))\n",
        "\n",
        "if is_fp16_used:\n",
        "  # Move our batch to FP16\n",
        "  batch[\"image\"] = batch[\"image\"].half()\n",
        "\n",
        "batch"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TR4hElv9tNfI",
        "colab": {}
      },
      "source": [
        "# predict_batch will automatically move the batch to the Runner's device\n",
        "\n",
        "tta_predictions = tta_runner.predict_batch(batch)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "UjQG8lAgtjU5",
        "colab": {}
      },
      "source": [
        "tta_predictions"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "D1umd2aowdDm"
      },
      "source": [
        "Shape is `batch_size x channels x height x width`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "85BYUDQAvtyv",
        "colab": {}
      },
      "source": [
        "tta_predictions[\"logits\"].shape"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "h6VXwTBRwX1x"
      },
      "source": [
        "Let's see our mask after TTA"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "JSDjXaC5v86T",
        "colab": {}
      },
      "source": [
        "threshold = 0.5\n",
        "\n",
        "image = detach(batch[\"image\"][0]).astype(float)\n",
        "\n",
        "\n",
        "image = image.transpose(1, 2, 0)\n",
        "\n",
        "logits = tta_predictions[\"logits\"][0, 0].sigmoid()\n",
        "mask = detach(logits > threshold).astype(\"uint8\")\n",
        "\n",
        "show_examples(name=\"\", image=image, mask=mask)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}